/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search.join;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Locale;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.internal.hppc.LongCursor;
import org.apache.lucene.internal.hppc.LongFloatHashMap;
import org.apache.lucene.internal.hppc.LongHashSet;
import org.apache.lucene.internal.hppc.LongIntHashMap;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointInSetQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.join.DocValuesTermsCollector.Function;
import org.apache.lucene.util.BytesRef;

/**
 * Utility for query time joining.
 *
 * @lucene.experimental
 */
public final class JoinUtil {

  /** Not instantiable: this class only exposes static helper methods. */
  private JoinUtil() {}

  /**
   * Creates a query time join over term based (doc values) join fields.
   *
   * <p>The returned query, when executed with an {@link IndexSearcher}, matches every document
   * whose <code>toField</code> contains a term that also appears in the <code>fromField</code> of a
   * document matched by <code>fromQuery</code>.
   *
   * <p>Set <code>multipleValuesPerDocument</code> to <code>true</code> when a single document can
   * relate to more than one document. In that mode only the score belonging to the first
   * encountered join value on the 'from' side is carried over to the 'to' side, even if a later
   * join value for the same document would have produced a higher score. This is irrelevant for
   * {@link ScoreMode#None}, because no scores are computed then.
   *
   * <p>Memory considerations: all unique join values are held in memory while joining. Unless the
   * scoreMode is {@link ScoreMode#None}, one float per unique join value is kept as well in order
   * to compute scores, and {@link ScoreMode#Avg} additionally keeps one integer per unique join
   * value.
   *
   * @param fromField The from field to join from
   * @param multipleValuesPerDocument Whether the from field has multiple terms per document
   * @param toField The to field to join to
   * @param fromQuery The query to match documents on the from side
   * @param fromSearcher The searcher that executed the specified fromQuery
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @return a {@link Query} instance that can be used to join documents based on the terms in the
   *     from and to field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(
      String fromField,
      boolean multipleValuesPerDocument,
      String toField,
      Query fromQuery,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode)
      throws IOException {

    // Pick the collector flavour that matches the cardinality of the from field.
    final GenericTermsCollector fromTermsCollector;
    if (multipleValuesPerDocument == false) {
      Function<SortedDocValues> singleValued = DocValuesTermsCollector.sortedDocValues(fromField);
      fromTermsCollector = GenericTermsCollector.createCollectorSV(singleValued, scoreMode);
    } else {
      Function<SortedSetDocValues> multiValued =
          DocValuesTermsCollector.sortedSetDocValues(fromField);
      fromTermsCollector = GenericTermsCollector.createCollectorMV(multiValued, scoreMode);
    }

    // The private overload runs the from query and builds the actual join query.
    return createJoinQuery(
        multipleValuesPerDocument,
        toField,
        fromQuery,
        fromField,
        fromSearcher,
        scoreMode,
        fromTermsCollector);
  }

  /**
   * Method for query time joining for numeric fields. It supports multi- and single- values longs,
   * ints, floats and doubles. All considerations from {@link JoinUtil#createJoinQuery(String,
   * boolean, String, Query, IndexSearcher, ScoreMode)} are applicable here too, though memory
   * consumption might be higher.
   *
   * <p>The <code>numericType</code> is validated before <code>fromQuery</code> is executed, so an
   * unsupported type is rejected without running the search.
   *
   * @param fromField The from field to join from
   * @param multipleValuesPerDocument Whether the from field has multiple terms per document when
   *     true fromField might be {@link DocValuesType#SORTED_NUMERIC}, otherwise fromField should be
   *     {@link DocValuesType#NUMERIC}
   * @param toField The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link
   *     FloatPoint} or {@link DoublePoint}.
   * @param numericType either {@link java.lang.Integer}, {@link java.lang.Long}, {@link
   *     java.lang.Float} or {@link java.lang.Double} it should correspond to toField types
   * @param fromQuery The query to match documents on the from side
   * @param fromSearcher The searcher that executed the specified fromQuery
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @return a {@link Query} instance that can be used to join documents based on the terms in the
   *     from and to field
   * @throws IOException If I/O related errors occur
   * @throws IllegalArgumentException If numericType is not one of the supported types
   */
  public static Query createJoinQuery(
      String fromField,
      boolean multipleValuesPerDocument,
      String toField,
      Class<? extends Number> numericType,
      Query fromQuery,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode)
      throws IOException {
    // Resolve the point encoding first: an unsupported numericType must fail fast, before the
    // (potentially expensive) from query is executed.
    final BytesRef encoded = new BytesRef();
    final int bytesPerDim;
    // Writes the point encoding of a collected join value into encoded.bytes.
    final java.util.function.LongConsumer encoder;
    if (Integer.class.equals(numericType)) {
      bytesPerDim = Integer.BYTES;
      encoder = value -> IntPoint.encodeDimension((int) value, encoded.bytes, 0);
    } else if (Long.class.equals(numericType)) {
      bytesPerDim = Long.BYTES;
      encoder = value -> LongPoint.encodeDimension(value, encoded.bytes, 0);
    } else if (Float.class.equals(numericType)) {
      bytesPerDim = Float.BYTES;
      // The collected long carries the int bits of the float value.
      encoder =
          value ->
              FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0);
    } else if (Double.class.equals(numericType)) {
      bytesPerDim = Double.BYTES;
      // The collected long carries the long bits of the double value.
      encoder =
          value ->
              DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0);
    } else {
      throw new IllegalArgumentException(
          "unsupported numeric type, only Integer, Long, Float and Double are supported");
    }
    // The single scratch buffer is reused for every value handed out by the stream below.
    encoded.bytes = new byte[bytesPerDim];
    encoded.length = bytesPerDim;

    LongHashSet joinValues = new LongHashSet();
    LongFloatHashMap aggregatedScores = new LongFloatHashMap();
    LongIntHashMap occurrences = new LongIntHashMap();
    boolean needsScore = scoreMode != ScoreMode.None;

    // Folds the score of a matching from-document into the per-join-value aggregate.
    LongFloatProcedure scoreAggregator;
    if (scoreMode == ScoreMode.Max) {
      scoreAggregator =
          (key, score) -> {
            int index = aggregatedScores.indexOf(key);
            if (index < 0) {
              aggregatedScores.indexInsert(index, key, score);
            } else {
              float currentScore = aggregatedScores.indexGet(index);
              aggregatedScores.indexReplace(index, Math.max(currentScore, score));
            }
          };
    } else if (scoreMode == ScoreMode.Min) {
      scoreAggregator =
          (key, score) -> {
            int index = aggregatedScores.indexOf(key);
            if (index < 0) {
              aggregatedScores.indexInsert(index, key, score);
            } else {
              float currentScore = aggregatedScores.indexGet(index);
              aggregatedScores.indexReplace(index, Math.min(currentScore, score));
            }
          };
    } else if (scoreMode == ScoreMode.Total) {
      scoreAggregator = aggregatedScores::addTo;
    } else if (scoreMode == ScoreMode.Avg) {
      scoreAggregator =
          (key, score) -> {
            aggregatedScores.addTo(key, score);
            occurrences.addTo(key, 1);
          };
    } else {
      // Never invoked for ScoreMode.None because the collectors only aggregate when needsScore.
      scoreAggregator =
          (_, _) -> {
            throw new UnsupportedOperationException();
          };
    }

    // Produces the final score of a join value once collection has finished.
    LongFloatFunction joinScorer;
    if (scoreMode == ScoreMode.Avg) {
      joinScorer =
          (joinValue) -> {
            float aggregatedScore = aggregatedScores.get(joinValue);
            int occurrence = occurrences.get(joinValue);
            return aggregatedScore / occurrence;
          };
    } else {
      joinScorer = aggregatedScores::get;
    }

    Collector collector;
    if (multipleValuesPerDocument) {
      collector =
          new SimpleCollector() {

            SortedNumericDocValues sortedNumericDocValues;
            Scorable scorer;

            @Override
            public void collect(int doc) throws IOException {
              // Every value of the document becomes a join value; the document score is
              // aggregated once per value.
              if (sortedNumericDocValues.advanceExact(doc)) {
                for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) {
                  long value = sortedNumericDocValues.nextValue();
                  joinValues.add(value);
                  if (needsScore) {
                    scoreAggregator.apply(value, scorer.score());
                  }
                }
              }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
              sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
              this.scorer = scorer;
            }

            @Override
            public org.apache.lucene.search.ScoreMode scoreMode() {
              return needsScore
                  ? org.apache.lucene.search.ScoreMode.COMPLETE
                  : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
            }
          };
    } else {
      collector =
          new SimpleCollector() {

            NumericDocValues numericDocValues;
            Scorable scorer;
            private int lastDocID = -1;

            // Only evaluated through the assert in collect(), i.e. when assertions are enabled.
            private boolean docsInOrder(int docID) {
              if (docID < lastDocID) {
                throw new AssertionError(
                    "docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
              }
              lastDocID = docID;
              return true;
            }

            @Override
            public void collect(int doc) throws IOException {
              assert docsInOrder(doc);
              // A document without a value for fromField contributes the join value 0.
              long value = 0;
              if (numericDocValues.advanceExact(doc)) {
                value = numericDocValues.longValue();
              }
              joinValues.add(value);
              if (needsScore) {
                scoreAggregator.apply(value, scorer.score());
              }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
              numericDocValues = DocValues.getNumeric(context.reader(), fromField);
              // Doc ids restart for every segment.
              lastDocID = -1;
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
              this.scorer = scorer;
            }

            @Override
            public org.apache.lucene.search.ScoreMode scoreMode() {
              return needsScore
                  ? org.apache.lucene.search.ScoreMode.COMPLETE
                  : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
            }
          };
    }
    fromSearcher.search(fromQuery, collector);

    // NOTE(review): the join values are ordered by their raw long representation. For Float and
    // Double this may not match the order of the encoded points when negative values are present
    // -- TODO confirm against the ordering requirements of the point-in-set queries.
    LongArrayList joinValuesList = new LongArrayList(joinValues.size());
    joinValuesList.addAll(joinValues);
    Arrays.sort(joinValuesList.buffer, 0, joinValuesList.size());
    Iterator<LongCursor> iterator = joinValuesList.iterator();

    // One stream serves all numeric types; only the encoder resolved above differs.
    final PointInSetIncludingScoreQuery.Stream stream =
        new PointInSetIncludingScoreQuery.Stream() {
          @Override
          public BytesRef next() {
            if (iterator.hasNext() == false) {
              return null;
            }
            long value = iterator.next().value;
            encoder.accept(value);
            if (needsScore) {
              score = joinScorer.apply(value);
            }
            return encoded;
          }
        };

    if (needsScore) {
      return new PointInSetIncludingScoreQuery(
          scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {

        @Override
        protected String toString(byte[] value) {
          return toString.apply(value, numericType);
        }
      };
    } else {
      return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
        @Override
        protected String toString(byte[] value) {
          return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
        }
      };
    }
  }

  /**
   * Executes <code>fromQuery</code> with the given terms collector and wraps what was collected
   * into the join query that fits the score mode: a {@link TermsQuery} for {@link ScoreMode#None}
   * and a {@link TermsIncludingScoreQuery} for all scoring modes.
   *
   * @throws IllegalArgumentException if the score mode is not handled by any case
   * @throws IOException If I/O related errors occur
   */
  private static Query createJoinQuery(
      boolean multipleValuesPerDocument,
      String toField,
      Query fromQuery,
      String fromField,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode,
      final GenericTermsCollector collector)
      throws IOException {

    // Collect the join terms (and scores, when requested) from the 'from' side first.
    fromSearcher.search(fromQuery, collector);

    return switch (scoreMode) {
      case None ->
          new TermsQuery(
              toField,
              collector.getCollectedTerms(),
              fromField,
              fromQuery,
              fromSearcher.getTopReaderContext().id());
      case Total, Max, Min, Avg ->
          new TermsIncludingScoreQuery(
              scoreMode,
              toField,
              multipleValuesPerDocument,
              collector.getCollectedTerms(),
              collector.getScoresPerTerm(),
              fromField,
              fromQuery,
              fromSearcher.getTopReaderContext().id());
      default ->
          throw new IllegalArgumentException(
              String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
    };
  }

  /**
   * Convenience variant of {@link #createJoinQuery(String, Query, Query, IndexSearcher, ScoreMode,
   * OrdinalMap, int, int)} that switches the min and max filtering off.
   *
   * @param joinField The {@link SortedDocValues} field containing the join values
   * @param fromQuery The query containing the actual user query. Also the fromQuery can only match
   *     "from" documents.
   * @param toQuery The query identifying all documents on the "to" side.
   * @param searcher The index searcher used to execute the from query
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment
   *     index, no ordinal map needs to be provided.
   * @return a {@link Query} instance that can be used to join documents based on the join field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(
      String joinField,
      Query fromQuery,
      Query toQuery,
      IndexSearcher searcher,
      ScoreMode scoreMode,
      OrdinalMap ordinalMap)
      throws IOException {
    // This pair of bounds is what the full overload documents as "filtering disabled".
    final int unboundedMin = 0;
    final int unboundedMax = Integer.MAX_VALUE;
    return createJoinQuery(
        joinField, fromQuery, toQuery, searcher, scoreMode, ordinalMap, unboundedMin, unboundedMax);
  }

  /**
   * A query time join using global ordinals over a dedicated join field.
   *
   * <p>Restrictions and requirements of this join:
   *
   * <ol>
   *   <li>A document can only refer to one other document (but can be referred to by one or more
   *       documents).
   *   <li>Documents on each side of the join must be distinguishable. Typically this is done by
   *       adding an extra field that identifies the "from" and "to" side; the fromQuery and
   *       toQuery must then take that field into account.
   *   <li>There must be a single sorted doc values join field used by both the "from" and "to"
   *       documents. This join field should store the join values as UTF-8 strings.
   *   <li>An ordinal map must be provided that is created on top of the join field.
   * </ol>
   *
   * <p>Note: min and max filtering and the avg score mode require this join to keep track of the
   * number of times a document matches per join value. This increases the per join cost in terms
   * of execution time and memory.
   *
   * @param joinField The {@link SortedDocValues} field containing the join values
   * @param fromQuery The query containing the actual user query. Also the fromQuery can only match
   *     "from" documents.
   * @param toQuery The query identifying all documents on the "to" side.
   * @param searcher The index searcher used to execute the from query
   * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query
   * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment
   *     index, no ordinal map needs to be provided.
   * @param min Optionally the minimum number of "from" documents that are required to match for a
   *     "to" document to be a match. The min is inclusive. Setting min to 0 and max to <code>
   *     Integer.MAX_VALUE</code> disables the min and max "from" documents filtering
   * @param max Optionally the maximum number of "from" documents that are allowed to match for a
   *     "to" document to be a match. The max is inclusive. Setting min to 0 and max to <code>
   *     Integer.MAX_VALUE</code> disables the min and max "from" documents filtering
   * @return a {@link Query} instance that can be used to join documents based on the join field
   * @throws IOException If I/O related errors occur
   */
  public static Query createJoinQuery(
      String joinField,
      Query fromQuery,
      Query toQuery,
      IndexSearcher searcher,
      ScoreMode scoreMode,
      OrdinalMap ordinalMap,
      int min,
      int max)
      throws IOException {
    var leaves = searcher.getIndexReader().leaves();
    if (leaves.isEmpty()) {
      return new MatchNoDocsQuery("JoinUtil.createJoinQuery with no segments");
    }

    final long valueCount;
    if (leaves.size() == 1) {
      // No need to use the ordinal map, because there is just one segment.
      ordinalMap = null;
      LeafReader onlySegment = leaves.get(0).reader();
      SortedDocValues joinValues = onlySegment.getSortedDocValues(joinField);
      if (joinValues == null) {
        return new MatchNoDocsQuery("JoinUtil.createJoinQuery: no join values");
      }
      valueCount = joinValues.getValueCount();
    } else {
      if (ordinalMap == null) {
        throw new IllegalArgumentException(
            "OrdinalMap is required, because there is more than 1 segment");
      }
      valueCount = ordinalMap.getValueCount();
    }

    final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
    final Query rewrittenToQuery = searcher.rewrite(toQuery);

    // Without scores, and with bounds that only ask for "at least one match", the plain ordinals
    // collector is used instead of a score collector.
    final boolean matchCountIrrelevant = min <= 1 && max == Integer.MAX_VALUE;
    if (scoreMode == ScoreMode.None && matchCountIrrelevant) {
      GlobalOrdinalsCollector ordinalsCollector =
          new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount);
      searcher.search(rewrittenFromQuery, ordinalsCollector);
      return new GlobalOrdinalsQuery(
          ordinalsCollector.getCollectorOrdinals(),
          joinField,
          ordinalMap,
          rewrittenToQuery,
          rewrittenFromQuery,
          searcher.getTopReaderContext().id());
    }

    final GlobalOrdinalsWithScoreCollector scoreCollector =
        switch (scoreMode) {
          case Total ->
              new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount, min, max);
          case Min ->
              new GlobalOrdinalsWithScoreCollector.Min(joinField, ordinalMap, valueCount, min, max);
          case Max ->
              new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount, min, max);
          case Avg ->
              new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount, min, max);
          case None ->
              new GlobalOrdinalsWithScoreCollector.NoScore(
                  joinField, ordinalMap, valueCount, min, max);
          default ->
              throw new IllegalArgumentException(
                  String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
        };

    searcher.search(rewrittenFromQuery, scoreCollector);
    return new GlobalOrdinalsWithScoreQuery(
        scoreCollector,
        scoreMode,
        joinField,
        ordinalMap,
        rewrittenToQuery,
        rewrittenFromQuery,
        min,
        max,
        searcher.getTopReaderContext().id());
  }

  /**
   * Similar to {@link java.util.function.LongFunction} for primitive argument and result: maps a
   * {@code long} to a {@code float} without boxing.
   */
  @FunctionalInterface
  private interface LongFloatFunction {

    /**
     * Computes the float result for the given long.
     *
     * @param value the input value
     * @return the result for {@code value}
     */
    float apply(long value);
  }

  /**
   * Similar to {@link java.util.function.BiConsumer} for primitive arguments: consumes a {@code
   * long} key together with a {@code float} value without boxing.
   */
  @FunctionalInterface
  private interface LongFloatProcedure {

    /**
     * Consumes the given key/value pair.
     *
     * @param key the long key
     * @param value the float value associated with {@code key}
     */
    void apply(long key, float value);
  }
}
